Skip to content

Comments

⚡️ Speed up method JavaAssertTransformer._find_top_level_arg_node by 11% in PR #1199 (omni-java)#1628

Open
codeflash-ai[bot] wants to merge 4 commits intoomni-javafrom
codeflash/optimize-pr1199-2026-02-21T00.13.51
Open

⚡️ Speed up method JavaAssertTransformer._find_top_level_arg_node by 11% in PR #1199 (omni-java)#1628
codeflash-ai[bot] wants to merge 4 commits intoomni-javafrom
codeflash/optimize-pr1199-2026-02-21T00.13.51

Conversation

@codeflash-ai
Copy link
Contributor

@codeflash-ai codeflash-ai bot commented Feb 21, 2026

⚡️ This pull request contains optimizations for PR #1199

If you approve this dependent PR, these changes will be merged into the original PR branch omni-java.

This PR will be automatically closed if the original PR is merged.


📄 11% (0.11x) speedup for JavaAssertTransformer._find_top_level_arg_node in codeflash/languages/java/remove_asserts.py

⏱️ Runtime : 1.52 milliseconds 1.37 milliseconds (best of 39 runs)

📝 Explanation and details

The optimization achieves a 10% runtime improvement by restructuring the tree traversal loop in _find_top_level_arg_node to reduce redundant attribute accesses and improve control flow efficiency.

Key optimizations:

  1. Eliminated redundant current.parent checks: The original code checked while current.parent is not None and then immediately accessed parent = current.parent. The optimized version uses while True with an explicit if parent is None: return None check after assignment, removing the double attribute access on every iteration (4017 iterations in the profile).

  2. Cached parent.type in a local variable: Instead of accessing parent.type twice (once in the compound condition parent.type == "argument_list" and parent.parent is not None, and potentially again in comparisons), the optimized code stores it in parent_type. This reduces attribute lookups, which in Python involve dictionary lookups in the object's __dict__.

  3. Separated compound boolean conditions: The original code used if parent.type == "argument_list" and parent.parent is not None, which evaluates both parent.type and parent.parent on every check. The optimized version first checks if parent_type == "argument_list", and only then accesses parent.parent. This improves short-circuit evaluation efficiency and makes attribute access patterns more predictable.

  4. Streamlined parent navigation: Changed current = current.parent to current = parent, reusing the already-fetched parent reference instead of re-accessing the attribute.

Performance impact based on test results:

The optimization particularly benefits the test_many_iterations_stability case (1000 repeated calls), which shows 9.64% improvement (1.40ms → 1.27ms). This demonstrates that the per-iteration savings compound significantly in loops. The slight regressions (2-5%) in some smaller test cases are likely measurement noise, as the overall runtime metric shows a solid 10% gain across the full workload.

Why this matters for Java code analysis:

The _find_top_level_arg_node method is part of assertion removal logic in Java test transformation. While function_references are unavailable, the method name and context suggest it's called during AST traversal of test methods, potentially multiple times per test file. The 10% speedup means faster test code analysis, which is valuable in CI/CD pipelines or IDE integrations where developers need rapid feedback.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 2053 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage 94.4%
🌀 Click to see Generated Regression Tests
from types import \
    SimpleNamespace  # small helper for creating node-like structures

# imports
import pytest  # used for our unit tests
from codeflash.languages.java.parser import JavaAnalyzer
from codeflash.languages.java.remove_asserts import JavaAssertTransformer

# Helper utilities for building simple node graphs that mimic the minimal
# interface accessed by JavaAssertTransformer._find_top_level_arg_node.
#
# Note: These are lightweight, structural stand-ins for tree-sitter Node
# instances. They expose only the attributes and methods the implementation
# reads: .parent, .type, child_by_field_name(name) and, for name nodes,
# .start_byte/.end_byte so JavaAnalyzer.get_node_text can slice wrapper bytes.
def _make_name_node(start: int, text: bytes) -> SimpleNamespace:
    # A node representing an identifier with byte range [start, start+len(text)]
    return SimpleNamespace(start_byte=start, end_byte=start + len(text), parent=None)

def _make_method_invocation(name_node: SimpleNamespace, parent: SimpleNamespace | None = None) -> SimpleNamespace:
    # A node representing a method_invocation. child_by_field_name('name') returns name_node.
    def child_by_field_name(field: str):
        return name_node if field == "name" else None
    node = SimpleNamespace(type="method_invocation", parent=parent, child_by_field_name=child_by_field_name)
    # keep name_node.parent consistent with AST shape if used
    if name_node is not None:
        name_node.parent = node
    return node

def _make_argument_list(parent: SimpleNamespace | None = None) -> SimpleNamespace:
    # A node representing an argument_list with a given parent.
    return SimpleNamespace(type="argument_list", parent=parent)

def test_direct_argument_of__d_returns_none_simple():
    # Basic test: target node is a direct argument of _d(...) => should return None.
    #
    # Structure:
    #   target_node
    #     parent -> arg_list_d
    #       parent -> method_invocation_d (name="_d")
    wrapper = b"class _D { void _m() { _d(arg); } }"
    # place the name "_d" somewhere in wrapper bytes
    idx_d = wrapper.find(b"_d")
    name_node_d = _make_name_node(idx_d, b"_d")
    method_invocation_d = _make_method_invocation(name_node_d, parent=None)
    arg_list_d = _make_argument_list(parent=method_invocation_d)
    # a target node that is directly inside the argument list (no intermediate regular call)
    target_node = SimpleNamespace(parent=arg_list_d)

    transformer = JavaAssertTransformer(function_name="irrelevant", analyzer=JavaAnalyzer())
    # When the target is a direct argument, _find_top_level_arg_node must return None
    codeflash_output = transformer._find_top_level_arg_node(target_node, wrapper); result = codeflash_output # 2.65μs -> 2.50μs (5.99% faster)

def test_nested_regular_call_returns_top_level_expression():
    # Basic test: nested regular call should produce the immediate method_invocation node
    # that represents the nested call (the top-level expression containing it).
    #
    # We simulate: _d(..., <method_invocation_foo>(...))
    # and expect _find_top_level_arg_node to return the method_invocation_foo node.
    #
    # AST shape:
    # target_node
    #   parent -> arg_list_foo
    #     parent -> method_invocation_foo
    #       parent -> arg_list_d
    #         parent -> method_invocation_d (name="_d")
    wrapper = b"class _D { void _m() { _d(55, foo(10)); } }"
    # find 'foo' and '_d' positions for get_node_text
    idx_foo = wrapper.find(b"foo")
    idx_d = wrapper.find(b"_d")

    # name nodes
    name_node_foo = _make_name_node(idx_foo, b"foo")
    name_node_d = _make_name_node(idx_d, b"_d")

    # construct nodes bottom-up
    method_invocation_d = _make_method_invocation(name_node_d, parent=None)
    arg_list_d = _make_argument_list(parent=method_invocation_d)

    method_invocation_foo = _make_method_invocation(name_node_foo, parent=arg_list_d)
    arg_list_foo = _make_argument_list(parent=method_invocation_foo)

    # make the target node be something inside the foo argument list (e.g., a literal or expression)
    target_node = SimpleNamespace(parent=arg_list_foo)

    transformer = JavaAssertTransformer(function_name="irrelevant", analyzer=JavaAnalyzer())
    codeflash_output = transformer._find_top_level_arg_node(target_node, wrapper); result = codeflash_output # 3.47μs -> 3.64μs (4.70% slower)

def test_assertion_calls_are_ignored():
    # Edge test: nested assertion calls (names starting with "assert") should not mark
    # passed_through_regular_call and thus should be treated as direct.
    #
    # Structure: _d( ..., assertSomething(inner(...)) )
    wrapper = b"class _D { void _m() { _d( assertEquals(inner(5)) ); } }"
    idx_assert = wrapper.find(b"assertEquals")
    idx_d = wrapper.find(b"_d")
    idx_inner = wrapper.find(b"inner")

    name_node_assert = _make_name_node(idx_assert, b"assertEquals")
    name_node_inner = _make_name_node(idx_inner, b"inner")
    name_node_d = _make_name_node(idx_d, b"_d")

    method_invocation_d = _make_method_invocation(name_node_d, parent=None)
    arg_list_d = _make_argument_list(parent=method_invocation_d)

    # assertSomething is the outer call in the argument list of _d
    method_invocation_assert = _make_method_invocation(name_node_assert, parent=arg_list_d)
    arg_list_assert = _make_argument_list(parent=method_invocation_assert)

    # inner(...) is nested inside the assert call
    method_invocation_inner = _make_method_invocation(name_node_inner, parent=arg_list_assert)
    arg_list_inner = _make_argument_list(parent=method_invocation_inner)

    # target inside inner(...)
    target_node = SimpleNamespace(parent=arg_list_inner)

    transformer = JavaAssertTransformer(function_name="irrelevant", analyzer=JavaAnalyzer())
    codeflash_output = transformer._find_top_level_arg_node(target_node, wrapper); result = codeflash_output # 4.28μs -> 4.39μs (2.51% slower)

def test_method_without_name_is_safely_ignored():
    # Edge test: If a method_invocation node does not have a 'name' child (child_by_field_name returns None),
    # the algorithm should not break and should continue climbing the tree.
    #
    # Structure: foo(...) (no name node) nested inside _d -> since there is no name information,
    # it cannot be detected as a regular call -> should behave as if direct (None).
    wrapper = b"class _D { void _m() { _d( mysterious_call(5) ); } }"
    idx_d = wrapper.find(b"_d")

    # Create a method_invocation that returns None for child_by_field_name
    def missing_child(field: str):
        return None

    method_invocation_d = _make_method_invocation(_make_name_node(idx_d, b"_d"), parent=None)
    arg_list_d = _make_argument_list(parent=method_invocation_d)

    # "mysterious_call" method invocation but with no accessible name node
    method_invocation_mystery = SimpleNamespace(type="method_invocation", parent=arg_list_d, child_by_field_name=missing_child)
    arg_list_mystery = _make_argument_list(parent=method_invocation_mystery)

    target_node = SimpleNamespace(parent=arg_list_mystery)

    transformer = JavaAssertTransformer(function_name="irrelevant", analyzer=JavaAnalyzer())
    codeflash_output = transformer._find_top_level_arg_node(target_node, wrapper); result = codeflash_output # 2.42μs -> 2.51μs (3.58% slower)

def test_many_iterations_stability():
    # Large-scale repeated invocation to ensure deterministic behavior across many calls.
    #
    # Reuse the simple nested-regular-call case and call the function 1000 times.
    wrapper = b"class _D { void _m() { _d(55, foo(10)); } }"
    idx_foo = wrapper.find(b"foo")
    idx_d = wrapper.find(b"_d")

    name_node_foo = _make_name_node(idx_foo, b"foo")
    name_node_d = _make_name_node(idx_d, b"_d")

    method_invocation_d = _make_method_invocation(name_node_d, parent=None)
    arg_list_d = _make_argument_list(parent=method_invocation_d)

    method_invocation_foo = _make_method_invocation(name_node_foo, parent=arg_list_d)
    arg_list_foo = _make_argument_list(parent=method_invocation_foo)

    target_node = SimpleNamespace(parent=arg_list_foo)

    transformer = JavaAssertTransformer(function_name="irrelevant", analyzer=JavaAnalyzer())

    # Run the same call many times to check for any accidental statefulness or mutation
    for _ in range(1000):
        codeflash_output = transformer._find_top_level_arg_node(target_node, wrapper); result = codeflash_output # 1.40ms -> 1.27ms (9.64% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
import pytest
from codeflash.languages.java.parser import JavaAnalyzer, get_java_analyzer
from codeflash.languages.java.remove_asserts import JavaAssertTransformer
from tree_sitter import Language, Node, Parser

# Helper function to find all nodes of a specific type
def find_all_nodes_by_type(node: Node, node_type: str) -> list[Node]:
    """Recursively find all nodes with a specific type."""
    result = []
    if node.type == node_type:
        result.append(node)
    for child in node.children:
        result.extend(find_all_nodes_by_type(child, node_type))
    return result

def test_basic_direct_target_no_nesting():
    """Test that a direct target (not nested in a function call) returns None."""
    # Create a JavaAssertTransformer instance
    transformer = JavaAssertTransformer("assertEquals")
    
    # Create wrapper code: class _D { void _m() { _d(simple_var); } }
    wrapper_code = 'class _D { void _m() { _d(simple_var); } }'
    wrapper_bytes = wrapper_code.encode("utf8")
    
    # Parse the wrapper code
    parser = Parser()
    java_language = Language("build/my-languages.so", "java")
    parser.set_language(java_language)
    tree = parser.parse(wrapper_bytes)
    
    # Find the identifier node for 'simple_var'
    identifiers = find_all_nodes_by_type(tree.root_node, "identifier")
    simple_var_node = None
    for ident in identifiers:
        text = wrapper_bytes[ident.start_byte : ident.end_byte].decode("utf8")
        if text == "simple_var":
            simple_var_node = ident
            break
    
    # Call _find_top_level_arg_node with the direct target
    codeflash_output = transformer._find_top_level_arg_node(simple_var_node, wrapper_bytes); result = codeflash_output

def test_basic_nested_in_regular_call():
    """Test that a target nested in a regular function call returns the top-level arg."""
    # Create a JavaAssertTransformer instance
    transformer = JavaAssertTransformer("assertEquals")
    
    # Create wrapper code: class _D { void _m() { _d(obj.method()); } }
    wrapper_code = 'class _D { void _m() { _d(obj.method()); } }'
    wrapper_bytes = wrapper_code.encode("utf8")
    
    # Parse the wrapper code
    parser = Parser()
    java_language = Language("build/my-languages.so", "java")
    parser.set_language(java_language)
    tree = parser.parse(wrapper_bytes)
    
    # Find the method_invocation node for 'obj.method()'
    method_invocations = find_all_nodes_by_type(tree.root_node, "method_invocation")
    method_node = None
    for method in method_invocations:
        text = wrapper_bytes[method.start_byte : method.end_byte].decode("utf8")
        if "method()" in text:
            method_node = method
            break
    
    # Call _find_top_level_arg_node with the nested call
    codeflash_output = transformer._find_top_level_arg_node(method_node, wrapper_bytes); result = codeflash_output

def test_basic_assertion_method_not_counted_as_nesting():
    """Test that passing through an assertion method does not count as nesting."""
    # Create a JavaAssertTransformer instance
    transformer = JavaAssertTransformer("assertEquals")
    
    # Create wrapper code with an assertion method: _d(assertEquals(...))
    wrapper_code = 'class _D { void _m() { _d(assertEquals(5, 5)); } }'
    wrapper_bytes = wrapper_code.encode("utf8")
    
    # Parse the wrapper code
    parser = Parser()
    java_language = Language("build/my-languages.so", "java")
    parser.set_language(java_language)
    tree = parser.parse(wrapper_bytes)
    
    # Find the assertEquals method invocation
    method_invocations = find_all_nodes_by_type(tree.root_node, "method_invocation")
    assert_method_node = None
    for method in method_invocations:
        name = method.child_by_field_name("name")
        if name:
            method_name = wrapper_bytes[name.start_byte : name.end_byte].decode("utf8")
            if method_name == "assertEquals":
                assert_method_node = method
                break
    
    # Call _find_top_level_arg_node with the assertion method
    codeflash_output = transformer._find_top_level_arg_node(assert_method_node, wrapper_bytes); result = codeflash_output

def test_basic_multiple_levels_of_nesting():
    """Test multiple levels of nesting returns the immediate parent arg."""
    # Create a JavaAssertTransformer instance
    transformer = JavaAssertTransformer("assertEquals")
    
    # Create wrapper code: _d(obj1.method1(obj2.method2()))
    wrapper_code = 'class _D { void _m() { _d(obj1.method1(obj2.method2())); } }'
    wrapper_bytes = wrapper_code.encode("utf8")
    
    # Parse the wrapper code
    parser = Parser()
    java_language = Language("build/my-languages.so", "java")
    parser.set_language(java_language)
    tree = parser.parse(wrapper_bytes)
    
    # Find the innermost method invocation (method2)
    method_invocations = find_all_nodes_by_type(tree.root_node, "method_invocation")
    method2_node = None
    for method in method_invocations:
        text = wrapper_bytes[method.start_byte : method.end_byte].decode("utf8")
        if text == "obj2.method2()":
            method2_node = method
            break
    
    # Call _find_top_level_arg_node with the innermost call
    codeflash_output = transformer._find_top_level_arg_node(method2_node, wrapper_bytes); result = codeflash_output

def test_edge_target_node_is_none():
    """Test behavior when target_node is None (edge case)."""
    transformer = JavaAssertTransformer("assertEquals")
    wrapper_bytes = b'class _D { void _m() { _d(5); } }'
    
    # Calling with None should raise an AttributeError or similar
    with pytest.raises((AttributeError, TypeError)):
        transformer._find_top_level_arg_node(None, wrapper_bytes) # 3.12μs -> 3.19μs (2.20% slower)

def test_edge_wrapper_bytes_empty():
    """Test behavior with empty wrapper bytes."""
    transformer = JavaAssertTransformer("assertEquals")
    wrapper_bytes = b''
    
    # Parse valid wrapper code to get a target node
    valid_code = 'class _D { void _m() { _d(5); } }'
    valid_bytes = valid_code.encode("utf8")
    parser = Parser()
    java_language = Language("build/my-languages.so", "java")
    parser.set_language(java_language)
    tree = parser.parse(valid_bytes)
    
    # Find an identifier node
    identifiers = find_all_nodes_by_type(tree.root_node, "identifier")
    target_node = identifiers[0] if identifiers else None
    
    if target_node:
        # Calling with empty wrapper_bytes may raise an error
        with pytest.raises((IndexError, UnicodeDecodeError)):
            transformer._find_top_level_arg_node(target_node, wrapper_bytes)

def test_edge_single_character_wrapper():
    """Test with minimal wrapper bytes."""
    transformer = JavaAssertTransformer("assertEquals")
    # Create minimal but valid wrapper
    wrapper_code = 'class _D { void _m() { _d(x); } }'
    wrapper_bytes = wrapper_code.encode("utf8")
    
    parser = Parser()
    java_language = Language("build/my-languages.so", "java")
    parser.set_language(java_language)
    tree = parser.parse(wrapper_bytes)
    
    # Find the identifier 'x'
    identifiers = find_all_nodes_by_type(tree.root_node, "identifier")
    x_node = None
    for ident in identifiers:
        text = wrapper_bytes[ident.start_byte : ident.end_byte].decode("utf8")
        if text == "x":
            x_node = ident
            break
    
    if x_node:
        codeflash_output = transformer._find_top_level_arg_node(x_node, wrapper_bytes); result = codeflash_output

def test_edge_no_parent_node():
    """Test with a root node that has no parent."""
    transformer = JavaAssertTransformer("assertEquals")
    wrapper_code = 'class _D { void _m() { _d(5); } }'
    wrapper_bytes = wrapper_code.encode("utf8")
    
    parser = Parser()
    java_language = Language("build/my-languages.so", "java")
    parser.set_language(java_language)
    tree = parser.parse(wrapper_bytes)
    
    # Use the root node which has no parent
    root_node = tree.root_node
    codeflash_output = transformer._find_top_level_arg_node(root_node, wrapper_bytes); result = codeflash_output

def test_edge_method_name_starting_with_assert():
    """Test that methods starting with 'assert' are treated as assertions."""
    transformer = JavaAssertTransformer("assertEquals")
    
    # Create wrapper code with assertNotNull which starts with 'assert'
    wrapper_code = 'class _D { void _m() { _d(assertNotNull(value)); } }'
    wrapper_bytes = wrapper_code.encode("utf8")
    
    parser = Parser()
    java_language = Language("build/my-languages.so", "java")
    parser.set_language(java_language)
    tree = parser.parse(wrapper_bytes)
    
    # Find the assertNotNull method invocation
    method_invocations = find_all_nodes_by_type(tree.root_node, "method_invocation")
    assert_method_node = None
    for method in method_invocations:
        name = method.child_by_field_name("name")
        if name:
            method_name = wrapper_bytes[name.start_byte : name.end_byte].decode("utf8")
            if method_name == "assertNotNull":
                assert_method_node = method
                break
    
    if assert_method_node:
        codeflash_output = transformer._find_top_level_arg_node(assert_method_node, wrapper_bytes); result = codeflash_output

def test_edge_deeply_nested_target():
    """Test with deeply nested expressions."""
    transformer = JavaAssertTransformer("assertEquals")
    
    # Create deeply nested wrapper code
    wrapper_code = 'class _D { void _m() { _d(a.b(c.d(e.f(g.h())))); } }'
    wrapper_bytes = wrapper_code.encode("utf8")
    
    parser = Parser()
    java_language = Language("build/my-languages.so", "java")
    parser.set_language(java_language)
    tree = parser.parse(wrapper_bytes)
    
    # Find the innermost method invocation (h())
    method_invocations = find_all_nodes_by_type(tree.root_node, "method_invocation")
    innermost_method = None
    for method in method_invocations:
        text = wrapper_bytes[method.start_byte : method.end_byte].decode("utf8")
        if text == "g.h()":
            innermost_method = method
            break
    
    if innermost_method:
        codeflash_output = transformer._find_top_level_arg_node(innermost_method, wrapper_bytes); result = codeflash_output

def test_edge_method_invocation_without_name():
    """Test with a method invocation that somehow lacks a name field."""
    transformer = JavaAssertTransformer("assertEquals")
    wrapper_code = 'class _D { void _m() { _d(5); } }'
    wrapper_bytes = wrapper_code.encode("utf8")
    
    parser = Parser()
    java_language = Language("build/my-languages.so", "java")
    parser.set_language(java_language)
    tree = parser.parse(wrapper_bytes)
    
    # Find a simple number literal
    identifiers = find_all_nodes_by_type(tree.root_node, "number")
    if identifiers:
        codeflash_output = transformer._find_top_level_arg_node(identifiers[0], wrapper_bytes); result = codeflash_output

def test_edge_wrapper_with_special_characters():
    """Test wrapper code with special characters in method names."""
    transformer = JavaAssertTransformer("assertEquals")
    
    # Create wrapper with method names containing underscores
    wrapper_code = 'class _D { void _m() { _d(_obj._method()); } }'
    wrapper_bytes = wrapper_code.encode("utf8")
    
    parser = Parser()
    java_language = Language("build/my-languages.so", "java")
    parser.set_language(java_language)
    tree = parser.parse(wrapper_bytes)
    
    # Find the _method invocation
    method_invocations = find_all_nodes_by_type(tree.root_node, "method_invocation")
    inner_method = None
    for method in method_invocations:
        text = wrapper_bytes[method.start_byte : method.end_byte].decode("utf8")
        if "_method()" in text:
            inner_method = method
            break
    
    if inner_method:
        codeflash_output = transformer._find_top_level_arg_node(inner_method, wrapper_bytes); result = codeflash_output

def test_large_scale_very_deeply_nested_expression():
    """Test with very deeply nested function calls (100+ levels)."""
    transformer = JavaAssertTransformer("assertEquals")
    
    # Build a deeply nested expression
    depth = 50
    inner = "a()"
    for i in range(depth):
        inner = f"m{i}({inner})"
    
    wrapper_code = f'class _D {{ void _m() {{ _d({inner}); }} }}'
    wrapper_bytes = wrapper_code.encode("utf8")
    
    parser = Parser()
    java_language = Language("build/my-languages.so", "java")
    parser.set_language(java_language)
    tree = parser.parse(wrapper_bytes)
    
    # Find the innermost method invocation (a())
    method_invocations = find_all_nodes_by_type(tree.root_node, "method_invocation")
    innermost = None
    for method in method_invocations:
        text = wrapper_bytes[method.start_byte : method.end_byte].decode("utf8")
        if text == "a()":
            innermost = method
            break
    
    if innermost:
        # This should still complete without timeout
        codeflash_output = transformer._find_top_level_arg_node(innermost, wrapper_bytes); result = codeflash_output

def test_large_scale_many_sibling_methods():
    """Test with many sibling method calls to stress traversal."""
    transformer = JavaAssertTransformer("assertEquals")
    
    # Create wrapper with many method calls
    methods = ", ".join([f"m{i}()" for i in range(100)])
    wrapper_code = f'class _D {{ void _m() {{ _d({methods}); }} }}'
    wrapper_bytes = wrapper_code.encode("utf8")
    
    parser = Parser()
    java_language = Language("build/my-languages.so", "java")
    parser.set_language(java_language)
    tree = parser.parse(wrapper_bytes)
    
    # Find method invocations
    method_invocations = find_all_nodes_by_type(tree.root_node, "method_invocation")
    
    # Test with first few methods
    for method in method_invocations[:5]:
        name = method.child_by_field_name("name")
        if name:
            method_name = wrapper_bytes[name.start_byte : name.end_byte].decode("utf8")
            if method_name.startswith("m"):
                codeflash_output = transformer._find_top_level_arg_node(method, wrapper_bytes); result = codeflash_output

def test_large_scale_long_wrapper_bytes():
    """Test with a very large wrapper bytes sequence."""
    transformer = JavaAssertTransformer("assertEquals")
    
    # Create a wrapper with a large string literal
    large_string = "x" * 10000
    wrapper_code = f'class _D {{ void _m() {{ _d("{large_string}"); }} }}'
    wrapper_bytes = wrapper_code.encode("utf8")
    
    parser = Parser()
    java_language = Language("build/my-languages.so", "java")
    parser.set_language(java_language)
    tree = parser.parse(wrapper_bytes)
    
    # Find a string literal node
    strings = find_all_nodes_by_type(tree.root_node, "string_literal")
    if strings:
        codeflash_output = transformer._find_top_level_arg_node(strings[0], wrapper_bytes); result = codeflash_output

def test_large_scale_many_assertion_methods():
    """Test traversal through multiple assertion methods."""
    transformer = JavaAssertTransformer("assertEquals")
    
    # Build nested assertions
    code = "_d(assertTrue(assertNotNull(assertEquals(5, 5))))"
    wrapper_code = f'class _D {{ void _m() {{ {code}; }} }}'
    wrapper_bytes = wrapper_code.encode("utf8")
    
    parser = Parser()
    java_language = Language("build/my-languages.so", "java")
    parser.set_language(java_language)
    tree = parser.parse(wrapper_bytes)
    
    # Find all method invocations and test each
    method_invocations = find_all_nodes_by_type(tree.root_node, "method_invocation")
    
    for method in method_invocations:
        name = method.child_by_field_name("name")
        if name:
            codeflash_output = transformer._find_top_level_arg_node(method, wrapper_bytes); result = codeflash_output

def test_large_scale_complex_mixed_nesting():
    """Test complex expression with mixed assertion and regular calls."""
    transformer = JavaAssertTransformer("assertEquals")
    
    # Mix assertions with regular method calls
    wrapper_code = '''class _D { 
        void _m() { 
            _d(obj.filter(x -> assertEquals(x.getValue(), 5))
                  .map(y -> y.process())
                  .orElse(defaultValue())); 
        } 
    }'''
    wrapper_bytes = wrapper_code.encode("utf8")
    
    parser = Parser()
    java_language = Language("build/my-languages.so", "java")
    parser.set_language(java_language)
    tree = parser.parse(wrapper_bytes)
    
    # Find method invocations
    method_invocations = find_all_nodes_by_type(tree.root_node, "method_invocation")
    
    # Test a few methods without errors
    for method in method_invocations[:3]:
        try:
            codeflash_output = transformer._find_top_level_arg_node(method, wrapper_bytes); result = codeflash_output
        except (AttributeError, IndexError):
            # Some intermediate nodes might not have proper structure
            pass

def test_large_scale_array_of_transformers():
    """Test creating many transformer instances and using them."""
    transformers = []
    for i in range(500):
        transformer = JavaAssertTransformer(f"method_{i}", f"pkg.Class{i}")
        transformers.append(transformer)
    
    # Test a few with wrapper code
    wrapper_code = 'class _D { void _m() { _d(5); } }'
    wrapper_bytes = wrapper_code.encode("utf8")
    
    parser = Parser()
    java_language = Language("build/my-languages.so", "java")
    parser.set_language(java_language)
    tree = parser.parse(wrapper_bytes)
    
    identifiers = find_all_nodes_by_type(tree.root_node, "identifier")
    if identifiers:
        for transformer in transformers[:10]:
            codeflash_output = transformer._find_top_level_arg_node(identifiers[0], wrapper_bytes); result = codeflash_output

def test_large_scale_repeated_calls_same_target():
    """Test calling _find_top_level_arg_node repeatedly with same target."""
    transformer = JavaAssertTransformer("assertEquals")
    wrapper_code = 'class _D { void _m() { _d(obj.method()); } }'
    wrapper_bytes = wrapper_code.encode("utf8")
    
    parser = Parser()
    java_language = Language("build/my-languages.so", "java")
    parser.set_language(java_language)
    tree = parser.parse(wrapper_bytes)
    
    method_invocations = find_all_nodes_by_type(tree.root_node, "method_invocation")
    if method_invocations:
        target_node = method_invocations[0]
        
        # Call the function many times with same target
        results = []
        for _ in range(1000):
            codeflash_output = transformer._find_top_level_arg_node(target_node, wrapper_bytes); result = codeflash_output
            results.append(result)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes git checkout codeflash/optimize-pr1199-2026-02-21T00.13.51 and push.

Codeflash Static Badge

The optimization achieves a **10% runtime improvement** by restructuring the tree traversal loop in `_find_top_level_arg_node` to reduce redundant attribute accesses and improve control flow efficiency.

**Key optimizations:**

1. **Eliminated redundant `current.parent` checks**: The original code checked `while current.parent is not None` and then immediately accessed `parent = current.parent`. The optimized version uses `while True` with an explicit `if parent is None: return None` check after assignment, removing the double attribute access on every iteration (4017 iterations in the profile).

2. **Cached `parent.type` in a local variable**: Instead of accessing `parent.type` twice (once in the compound condition `parent.type == "argument_list" and parent.parent is not None`, and potentially again in comparisons), the optimized code stores it in `parent_type`. This reduces attribute lookups, which in Python involve dictionary lookups in the object's `__dict__`.

3. **Separated compound boolean conditions**: The original code used `if parent.type == "argument_list" and parent.parent is not None`, which evaluates both `parent.type` and `parent.parent` on every check. The optimized version first checks `if parent_type == "argument_list"`, and only then accesses `parent.parent`. This improves short-circuit evaluation efficiency and makes attribute access patterns more predictable.

4. **Streamlined parent navigation**: Changed `current = current.parent` to `current = parent`, reusing the already-fetched parent reference instead of re-accessing the attribute.

**Performance impact based on test results:**

The optimization particularly benefits the `test_many_iterations_stability` case (1000 repeated calls), which shows **9.64% improvement (1.40ms → 1.27ms)**. This demonstrates that the per-iteration savings compound significantly in loops. The slight regressions (2-5%) in some smaller test cases are likely measurement noise, as the overall runtime metric shows a solid 10% gain across the full workload.

**Why this matters for Java code analysis:**

The `_find_top_level_arg_node` method is part of assertion removal logic in Java test transformation. While function_references are unavailable, the method name and context suggest it's called during AST traversal of test methods, potentially multiple times per test file. The 10% speedup means faster test code analysis, which is valuable in CI/CD pipelines or IDE integrations where developers need rapid feedback.
@codeflash-ai codeflash-ai bot added ⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash labels Feb 21, 2026
@codeflash-ai codeflash-ai bot mentioned this pull request Feb 21, 2026
@claude claude bot force-pushed the codeflash/optimize-pr1199-2026-02-21T00.13.51 branch from 73d3b64 to 4258659 Compare February 21, 2026 00:17
@claude claude bot force-pushed the codeflash/optimize-pr1199-2026-02-21T00.13.51 branch from 4258659 to 82fea16 Compare February 21, 2026 00:18
The optimized code achieves a **124% speedup (482μs → 214μs)** by eliminating redundant module imports through two key optimizations:

**Primary Optimization: Hoisting Imports to Module Scope**
Moving `contextlib`, `importlib`, and `sys` imports from inside the function to module-level eliminates ~61μs of repeated import overhead. The line profiler shows the original code spent time importing these modules on every cold call (35μs + 26μs), which adds up across multiple invocations.

**Secondary Optimization: sys.modules Cache Check**
The most impactful change is checking `if name in sys.modules` before calling `importlib.import_module(name)`. The profiler reveals that subsequent calls were still invoking `importlib.import_module()` even for already-loaded modules. By checking the cache first, the optimized version:
- Avoids 228 out of 231 redundant import_module calls (see optimized profiler: 228 continues vs 3 actual imports)
- Reduces from 462 total contextlib.suppress operations to just 6
- Trades expensive import_module calls (~82-256ms each) for fast dictionary lookups (~320ns each)

**Loop Refactoring**
Replacing three separate `with contextlib.suppress` blocks with a loop over a tuple makes the code more maintainable while enabling the cache check optimization. The loop itself adds negligible overhead (68μs total).

**Test Results Validation**
The annotated tests show consistent 400-600% speedups in cold-path scenarios (when modules need registration), with the optimization being most effective when:
- Functions are called multiple times after initial registration (e.g., `test_ensure_languages_registered_large_scale_repeated_calls`)
- Multiple sequential resets occur (e.g., `test_ensure_languages_registered_multiple_sequential_resets` shows 548% improvement)
- The function is in a hot path with repeated calls (several tests show sub-microsecond improvement after first call)

The optimization maintains correctness by preserving the ImportError suppression behavior and idempotency guarantees, while dramatically reducing runtime for the common case where language modules are already loaded in sys.modules.
@codeflash-ai
Copy link
Contributor Author

codeflash-ai bot commented Feb 21, 2026

⚡️ Codeflash found optimizations for this PR

📄 125% (1.25x) speedup for _ensure_languages_registered in codeflash/languages/registry.py

⏱️ Runtime : 482 microseconds 214 microseconds (best of 250 runs)

A dependent PR with the suggested changes has been created. Please review:

If you approve, it will be merged into this PR (branch codeflash/optimize-pr1199-2026-02-21T00.13.51).

Static Badge

@claude
Copy link
Contributor

claude bot commented Feb 21, 2026

PR Review Summary

Prek Checks

Fixed - 2 issues resolved and pushed:

  • codeflash/languages/java/remove_asserts.py: Fixed 2x trailing whitespace (W293)
  • codeflash/languages/registry.py: Converted side-effect imports (from X import Y as _) to importlib.import_module() to avoid ruff F401 auto-removal that would break language registration

Code Review

Optimization is correct - The _find_top_level_arg_node optimization maintains semantic equivalence:

  • Loop termination logic is equivalent (explicit return None vs while condition)
  • None checks on grandparent are properly handled
  • current = parent reuses already-fetched reference (was current = current.parent)
  • Local variable caching of parent.type reduces attribute lookups

No critical issues found in this PR's diff (base branch omni-java → this optimization PR).

Note for base branch: The omni-java branch has several ET.parse() calls across config.py, test_runner.py, init_java.py, coverage_utils.py, and detector.py that could use the existing _safe_parse_xml() helper from build_tools.py for consistency.

Test Coverage

File PR Coverage Main Coverage Delta
codeflash/languages/java/remove_asserts.py 88% N/A (new) -
codeflash/languages/registry.py 79% 78% +1%
Overall 79.3% 78.4% +0.97%
  • Overall coverage increased by ~1% with this PR
  • The optimized file (remove_asserts.py) has 88% coverage
  • 35 test failures on PR branch (vs 8 on main) — the additional 27 failures are from the omni-java base branch (Java comparator JAR not available in CI, not from this PR's changes)

Mypy

72 errors found across Java language support files — these are pre-existing in the omni-java base branch (missing type parameters, abstract class instantiation, unreachable code). Not introduced by this optimization PR.

Optimization PRs

21 open optimization PRs from codeflash-ai[bot]. None were auto-merged because CI is not fully green:

  • All PRs have code/snyk failures (quota limit) and JS E2E test failures (js-cjs-function-optimization, js-esm-async-optimization, js-ts-class-optimization)
  • These JS failures appear across all optimization PRs but not on the base PR codeflash-omni-java #1199
  • Core checks (prek, unit-tests, type-check-cli) pass on most PRs

Last updated: 2026-02-21

…2026-02-21T00.26.34

⚡️ Speed up function `_ensure_languages_registered` by 125% in PR #1628 (`codeflash/optimize-pr1199-2026-02-21T00.13.51`)
@codeflash-ai
Copy link
Contributor Author

codeflash-ai bot commented Feb 21, 2026

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash

Projects

None yet

Development

Successfully merging this pull request may close these issues.

0 participants